{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 翻訳 (PyTorch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Install the Transformers, Datasets, and Evaluate libraries to run this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install datasets evaluate transformers[sentencepiece]\n",
    "!pip install accelerate\n",
    "# To run the training on TPU, you will need to uncomment the following line:\n",
    "# !pip install cloud-tpu-client==0.10 torch==1.9.0 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl\n",
    "!apt install git-lfs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You will need to setup git, adapt your email and name in the following cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!git config --global user.email \"you@example.com\"\n",
    "!git config --global user.name \"Your Name\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You will also need to be logged in to the Hugging Face Hub. Execute the following and enter your credentials."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "raw_datasets = load_dataset(\"kde4\", lang1=\"en\", lang2=\"fr\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['id', 'translation'],\n",
       "        num_rows: 210173\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "raw_datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['id', 'translation'],\n",
       "        num_rows: 189155\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['id', 'translation'],\n",
       "        num_rows: 21018\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "split_datasets = raw_datasets[\"train\"].train_test_split(train_size=0.9, seed=20)\n",
    "split_datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "split_datasets[\"validation\"] = split_datasets.pop(\"test\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'en': 'Default to expanded threads',\n",
       " 'fr': 'Par défaut, développer les fils de discussion'}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "split_datasets[\"train\"][1][\"translation\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'translation_text': 'Par défaut pour les threads élargis'}]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "model_checkpoint = \"Helsinki-NLP/opus-mt-en-fr\"\n",
    "translator = pipeline(\"translation\", model=model_checkpoint)\n",
    "translator(\"Default to expanded threads\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'en': 'Unable to import %1 using the OFX importer plugin. This file is not the correct format.',\n",
       " 'fr': \"Impossible d'importer %1 en utilisant le module d'extension d'importation OFX. Ce fichier n'a pas un format correct.\"}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "split_datasets[\"train\"][172][\"translation\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'translation_text': \"Impossible d'importer %1 en utilisant le plugin d'importateur OFX. Ce fichier n'est pas le bon format.\"}]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "translator(\n",
    "    \"Unable to import %1 using the OFX importer plugin. This file is not the correct format.\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "model_checkpoint = \"Helsinki-NLP/opus-mt-en-fr\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, return_tensors=\"tf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "en_sentence = split_datasets[\"train\"][1][\"translation\"][\"en\"]\n",
    "fr_sentence = split_datasets[\"train\"][1][\"translation\"][\"fr\"]\n",
    "\n",
    "inputs = tokenizer(en_sentence)\n",
    "# `text_target=` tokenizes the sentence as target-language (French) text. It replaces\n",
    "# the deprecated `tokenizer.as_target_tokenizer()` context manager.\n",
    "targets = tokenizer(text_target=fr_sentence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['▁Par', '▁dé', 'f', 'aut', ',', '▁dé', 've', 'lop', 'per', '▁les', '▁fil', 's', '▁de', '▁discussion', '</s>']\n",
       "['▁Par', '▁défaut', ',', '▁développer', '▁les', '▁fils', '▁de', '▁discussion', '</s>']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "wrong_targets = tokenizer(fr_sentence)\n",
    "print(tokenizer.convert_ids_to_tokens(wrong_targets[\"input_ids\"]))\n",
    "print(tokenizer.convert_ids_to_tokens(targets[\"input_ids\"]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_input_length = 128\n",
    "max_target_length = 128\n",
    "\n",
    "\n",
    "def preprocess_function(examples):\n",
    "    \"\"\"Tokenize a batch of English/French pairs for seq2seq training.\n",
    "\n",
    "    Args:\n",
    "        examples: a batch whose \"translation\" column is a list of dicts\n",
    "            with \"en\" and \"fr\" keys.\n",
    "\n",
    "    Returns:\n",
    "        The tokenized English inputs, with the token ids of the French\n",
    "        sentences stored under the \"labels\" key.\n",
    "    \"\"\"\n",
    "    inputs = [ex[\"en\"] for ex in examples[\"translation\"]]\n",
    "    targets = [ex[\"fr\"] for ex in examples[\"translation\"]]\n",
    "    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)\n",
    "\n",
    "    # `text_target=` tokenizes with the target-language settings; it replaces the\n",
    "    # deprecated `as_target_tokenizer()` context manager. The targets are tokenized\n",
    "    # in a separate call so they keep their own truncation length.\n",
    "    labels = tokenizer(\n",
    "        text_target=targets, max_length=max_target_length, truncation=True\n",
    "    )\n",
    "\n",
    "    model_inputs[\"labels\"] = labels[\"input_ids\"]\n",
    "    return model_inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenized_datasets = split_datasets.map(\n",
    "    preprocess_function,\n",
    "    batched=True,\n",
    "    remove_columns=split_datasets[\"train\"].column_names,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForSeq2SeqLM\n",
    "\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import DataCollatorForSeq2Seq\n",
    "\n",
    "data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['attention_mask', 'input_ids', 'labels', 'decoder_input_ids'])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch = data_collator([tokenized_datasets[\"train\"][i] for i in range(1, 3)])\n",
    "batch.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[  577,  5891,     2,  3184,    16,  2542,     5,  1710,     0,  -100,\n",
       "          -100,  -100,  -100,  -100,  -100,  -100],\n",
       "        [ 1211,     3,    49,  9409,  1211,     3, 29140,   817,  3124,   817,\n",
       "           550,  7032,  5821,  7907, 12649,     0]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch[\"labels\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[59513,   577,  5891,     2,  3184,    16,  2542,     5,  1710,     0,\n",
       "         59513, 59513, 59513, 59513, 59513, 59513],\n",
       "        [59513,  1211,     3,    49,  9409,  1211,     3, 29140,   817,  3124,\n",
       "           817,   550,  7032,  5821,  7907, 12649]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch[\"decoder_input_ids\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[577, 5891, 2, 3184, 16, 2542, 5, 1710, 0]\n",
       "[1211, 3, 49, 9409, 1211, 3, 29140, 817, 3124, 817, 550, 7032, 5821, 7907, 12649, 0]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "for i in range(1, 3):\n",
    "    print(tokenized_datasets[\"train\"][i][\"labels\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install sacrebleu"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import evaluate\n",
    "\n",
    "metric = evaluate.load(\"sacrebleu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'score': 46.750469682990165,\n",
       " 'counts': [11, 6, 4, 3],\n",
       " 'totals': [12, 11, 10, 9],\n",
       " 'precisions': [91.67, 54.54, 40.0, 33.33],\n",
       " 'bp': 0.9200444146293233,\n",
       " 'sys_len': 12,\n",
       " 'ref_len': 13}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predictions = [\n",
    "    \"This plugin lets you translate web pages between several languages automatically.\"\n",
    "]\n",
    "references = [\n",
    "    [\n",
    "        \"This plugin allows you to automatically translate web pages between several languages.\"\n",
    "    ]\n",
    "]\n",
    "metric.compute(predictions=predictions, references=references)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'score': 1.683602693167689,\n",
       " 'counts': [1, 0, 0, 0],\n",
       " 'totals': [4, 3, 2, 1],\n",
       " 'precisions': [25.0, 16.67, 12.5, 12.5],\n",
       " 'bp': 0.10539922456186433,\n",
       " 'sys_len': 4,\n",
       " 'ref_len': 13}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predictions = [\"This This This This\"]\n",
    "references = [\n",
    "    [\n",
    "        \"This plugin allows you to automatically translate web pages between several languages.\"\n",
    "    ]\n",
    "]\n",
    "metric.compute(predictions=predictions, references=references)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'score': 0.0,\n",
       " 'counts': [2, 1, 0, 0],\n",
       " 'totals': [2, 1, 0, 0],\n",
       " 'precisions': [100.0, 100.0, 0.0, 0.0],\n",
       " 'bp': 0.004086771438464067,\n",
       " 'sys_len': 2,\n",
       " 'ref_len': 13}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predictions = [\"This plugin\"]\n",
    "references = [\n",
    "    [\n",
    "        \"This plugin allows you to automatically translate web pages between several languages.\"\n",
    "    ]\n",
    "]\n",
    "metric.compute(predictions=predictions, references=references)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def compute_metrics(eval_preds):\n",
    "    \"\"\"Decode predictions and labels, then score them with the loaded metric.\n",
    "\n",
    "    Args:\n",
    "        eval_preds: a `(predictions, labels)` pair of token-id arrays.\n",
    "\n",
    "    Returns:\n",
    "        A dict with a single \"bleu\" entry holding the metric's \"score\".\n",
    "    \"\"\"\n",
    "    preds, labels = eval_preds\n",
    "    # The model may return a tuple; only its first element is decoded\n",
    "    preds = preds[0] if isinstance(preds, tuple) else preds\n",
    "\n",
    "    # -100 cannot be decoded, so swap it for the pad token id first\n",
    "    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n",
    "\n",
    "    # Strip surrounding whitespace; each reference is wrapped in its own list\n",
    "    decoded_preds = [\n",
    "        text.strip()\n",
    "        for text in tokenizer.batch_decode(preds, skip_special_tokens=True)\n",
    "    ]\n",
    "    decoded_labels = [\n",
    "        [text.strip()]\n",
    "        for text in tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
    "    ]\n",
    "\n",
    "    bleu = metric.compute(predictions=decoded_preds, references=decoded_labels)\n",
    "    return {\"bleu\": bleu[\"score\"]}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import Seq2SeqTrainingArguments\n",
    "\n",
    "args = Seq2SeqTrainingArguments(\n",
    "    \"marian-finetuned-kde4-en-to-fr\",  # output directory\n",
    "    # No evaluation inside the training loop: `trainer.evaluate()` is called\n",
    "    # explicitly before and after `trainer.train()` in the cells below.\n",
    "    evaluation_strategy=\"no\",\n",
    "    save_strategy=\"epoch\",\n",
    "    learning_rate=2e-5,\n",
    "    per_device_train_batch_size=32,\n",
    "    per_device_eval_batch_size=64,\n",
    "    weight_decay=0.01,\n",
    "    save_total_limit=3,\n",
    "    num_train_epochs=3,\n",
    "    # Evaluate on generated sequences so `compute_metrics` can decode and score them\n",
    "    predict_with_generate=True,\n",
    "    fp16=True,\n",
    "    push_to_hub=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import Seq2SeqTrainer\n",
    "\n",
    "trainer = Seq2SeqTrainer(\n",
    "    model,\n",
    "    args,\n",
    "    train_dataset=tokenized_datasets[\"train\"],\n",
    "    eval_dataset=tokenized_datasets[\"validation\"],\n",
    "    data_collator=data_collator,\n",
    "    tokenizer=tokenizer,\n",
    "    compute_metrics=compute_metrics,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'eval_loss': 1.6964408159255981,\n",
       " 'eval_bleu': 39.26865061007616,\n",
       " 'eval_runtime': 965.8884,\n",
       " 'eval_samples_per_second': 21.76,\n",
       " 'eval_steps_per_second': 0.341}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.evaluate(max_length=max_target_length)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'eval_loss': 0.8558505773544312,\n",
       " 'eval_bleu': 52.94161337775576,\n",
       " 'eval_runtime': 714.2576,\n",
       " 'eval_samples_per_second': 29.426,\n",
       " 'eval_steps_per_second': 0.461,\n",
       " 'epoch': 3.0}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.evaluate(max_length=max_target_length)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'https://huggingface.co/sgugger/marian-finetuned-kde4-en-to-fr/commit/3601d621e3baae2bc63d3311452535f8f58f6ef3'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.push_to_hub(tags=\"translation\", commit_message=\"Training complete\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "\n",
    "tokenized_datasets.set_format(\"torch\")\n",
    "train_dataloader = DataLoader(\n",
    "    tokenized_datasets[\"train\"],\n",
    "    shuffle=True,\n",
    "    collate_fn=data_collator,\n",
    "    batch_size=8,\n",
    ")\n",
    "eval_dataloader = DataLoader(\n",
    "    tokenized_datasets[\"validation\"], collate_fn=data_collator, batch_size=8\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use PyTorch's AdamW: the copy formerly exported by `transformers` is deprecated\n",
    "# and no longer available in recent releases.\n",
    "from torch.optim import AdamW\n",
    "\n",
    "# torch's AdamW defaults to weight_decay=0.01, the same value given to the Trainer above.\n",
    "optimizer = AdamW(model.parameters(), lr=2e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from accelerate import Accelerator\n",
    "\n",
    "accelerator = Accelerator()\n",
    "model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(\n",
    "    model, optimizer, train_dataloader, eval_dataloader\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import get_scheduler\n",
    "\n",
    "num_train_epochs = 3\n",
    "num_update_steps_per_epoch = len(train_dataloader)\n",
    "num_training_steps = num_train_epochs * num_update_steps_per_epoch\n",
    "\n",
    "lr_scheduler = get_scheduler(\n",
    "    \"linear\",\n",
    "    optimizer=optimizer,\n",
    "    num_warmup_steps=0,\n",
    "    num_training_steps=num_training_steps,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'sgugger/marian-finetuned-kde4-en-to-fr-accelerate'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from huggingface_hub import Repository, get_full_repo_name\n",
    "\n",
    "model_name = \"marian-finetuned-kde4-en-to-fr-accelerate\"\n",
    "repo_name = get_full_repo_name(model_name)\n",
    "repo_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output_dir = \"marian-finetuned-kde4-en-to-fr-accelerate\"\n",
    "repo = Repository(output_dir, clone_from=repo_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def postprocess(predictions, labels):\n",
    "    \"\"\"Turn gathered prediction and label tensors into stripped text.\n",
    "\n",
    "    Args:\n",
    "        predictions: tensor of generated token ids.\n",
    "        labels: tensor of label token ids, where -100 marks positions to skip.\n",
    "\n",
    "    Returns:\n",
    "        `(decoded_preds, decoded_labels)`: a list of strings and a list of\n",
    "        one-element lists of strings.\n",
    "    \"\"\"\n",
    "    pred_ids = predictions.cpu().numpy()\n",
    "    label_ids = labels.cpu().numpy()\n",
    "\n",
    "    # -100 cannot be decoded, so swap it for the pad token id first.\n",
    "    label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)\n",
    "\n",
    "    # Strip surrounding whitespace; each reference is wrapped in its own list\n",
    "    decoded_preds = [\n",
    "        text.strip()\n",
    "        for text in tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n",
    "    ]\n",
    "    decoded_labels = [\n",
    "        [text.strip()]\n",
    "        for text in tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n",
    "    ]\n",
    "    return decoded_preds, decoded_labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "epoch 0, BLEU score: 53.47\n",
       "epoch 1, BLEU score: 54.24\n",
       "epoch 2, BLEU score: 54.44"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from tqdm.auto import tqdm\n",
    "import torch\n",
    "\n",
    "progress_bar = tqdm(range(num_training_steps))\n",
    "\n",
    "for epoch in range(num_train_epochs):\n",
    "    # Training\n",
    "    model.train()\n",
    "    for batch in train_dataloader:\n",
    "        outputs = model(**batch)\n",
    "        loss = outputs.loss\n",
    "        accelerator.backward(loss)\n",
    "\n",
    "        optimizer.step()\n",
    "        lr_scheduler.step()\n",
    "        optimizer.zero_grad()\n",
    "        progress_bar.update(1)\n",
    "\n",
    "    # Evaluation\n",
    "    model.eval()\n",
    "    for batch in tqdm(eval_dataloader):\n",
    "        with torch.no_grad():\n",
    "            generated_tokens = accelerator.unwrap_model(model).generate(\n",
    "                batch[\"input_ids\"],\n",
    "                attention_mask=batch[\"attention_mask\"],\n",
    "                max_length=128,\n",
    "            )\n",
    "        labels = batch[\"labels\"]\n",
    "\n",
    "        # Necessary to pad predictions and labels for being gathered\n",
    "        generated_tokens = accelerator.pad_across_processes(\n",
    "            generated_tokens, dim=1, pad_index=tokenizer.pad_token_id\n",
    "        )\n",
    "        labels = accelerator.pad_across_processes(labels, dim=1, pad_index=-100)\n",
    "\n",
    "        predictions_gathered = accelerator.gather(generated_tokens)\n",
    "        labels_gathered = accelerator.gather(labels)\n",
    "\n",
    "        decoded_preds, decoded_labels = postprocess(predictions_gathered, labels_gathered)\n",
    "        metric.add_batch(predictions=decoded_preds, references=decoded_labels)\n",
    "\n",
    "    results = metric.compute()\n",
    "    print(f\"epoch {epoch}, BLEU score: {results['score']:.2f}\")\n",
    "\n",
    "    # Save and upload\n",
    "    accelerator.wait_for_everyone()\n",
    "    unwrapped_model = accelerator.unwrap_model(model)\n",
    "    unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)\n",
    "    if accelerator.is_main_process:\n",
    "        tokenizer.save_pretrained(output_dir)\n",
    "        repo.push_to_hub(\n",
    "            commit_message=f\"Training in progress epoch {epoch}\", blocking=False\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'translation_text': 'Par défaut, développer les fils de discussion'}]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "# Replace this with your own checkpoint\n",
    "model_checkpoint = \"huggingface-course/marian-finetuned-kde4-en-to-fr\"\n",
    "translator = pipeline(\"translation\", model=model_checkpoint)\n",
    "translator(\"Default to expanded threads\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'translation_text': \"Impossible d'importer %1 en utilisant le module externe d'importation OFX. Ce fichier n'est pas le bon format.\"}]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "translator(\n",
    "    \"Unable to import %1 using the OFX importer plugin. This file is not the correct format.\"\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "name": "翻訳 (PyTorch)",
   "provenance": []
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
